from torch import nn

from lib.lorentz.layers.Kernels import get_learned_kernels
from lib.lorentz.layers import LorentzFullyConnected
from lib.geoopt import ManifoldParameter

class ClusterDecoder(nn.Module):
    """Score inputs by their manifold distance to a set of per-class centers.

    One center per class is produced by ``get_learned_kernels`` and stored as a
    ``ManifoldParameter`` on ``manifold``. When ``embed_dim`` is non-zero the
    input is first mapped from ``input_dim`` to ``embed_dim`` by a
    ``LorentzFullyConnected`` layer; otherwise it is passed through unchanged.

    Args:
        manifold: Manifold used for the centers, the optional projection and
            the distance computation (must provide ``dist``).
        input_dim: Feature dimension of the incoming points.
        num_classes: Number of centers to create.
        embed_dim: Output dimension of the optional projection. ``0`` (the
            default) disables the projection and sizes the centers with
            ``input_dim`` instead.
        learnable: Forwarded as ``requires_grad`` to the centers parameter;
            ``False`` keeps the centers frozen.
    """

    def __init__(self,
                 manifold,
                 input_dim,
                 num_classes,
                 embed_dim=0,
                 learnable=True
                 ):
        super().__init__()

        project_input = embed_dim != 0
        # A zero embed_dim is the "no projection" sentinel: the centers then
        # use the input's own dimension.
        self.embed_dim = embed_dim if project_input else input_dim

        # NOTE(review): the literal 200 is passed straight through to
        # get_learned_kernels; its meaning is defined there -- TODO confirm
        # and consider exposing it as a constructor argument.
        initial_centers = get_learned_kernels(
            num_classes, self.embed_dim, 200, manifold)
        self.centers = ManifoldParameter(
            initial_centers, manifold, requires_grad=learnable)

        self.manifold = manifold
        # An empty nn.Sequential returns its input unchanged, so forward can
        # call self.linear unconditionally whether or not a projection exists.
        if project_input:
            self.linear = LorentzFullyConnected(self.manifold, input_dim, embed_dim)
        else:
            self.linear = nn.Sequential()

    def forward(self, x):
        """Return the manifold distance from each input point to every center.

        ``x`` is passed through ``self.linear`` (identity when no projection
        was configured), given a singleton axis just before its last
        dimension, and broadcast against ``self.centers`` inside
        ``manifold.dist``. Assuming ``dist`` reduces over the last axis, the
        result has shape ``(*x.shape[:-1], num_classes)`` -- TODO confirm.
        The raw output of ``manifold.dist`` is returned; it is not negated or
        otherwise converted to logits here.
        """
        embedded = self.linear(x)
        per_point = embedded.unsqueeze(-2)
        return self.manifold.dist(per_point, self.centers)

